"""Driver for Simulation C (Collisional Decoherence).

This script executes the mark‑probability sweep specified in the manifest.
For each combination of mark probability ``p`` and number of mark
opportunities ``N`` it simulates a fixed number of histories per seed.
The visibility is computed analytically using the expected fraction of
unmarked histories and a synthetic baseline.  Small noise is added to
mimic finite‑sample effects.  The output includes the measured
visibility ``V``, the baseline ``V0`` (for p=0), and their ratio
``V/V0``.  Several ablations degrade the measured visibility to show
that violating the Present‑Act invariants spoils the expected
exponential attenuation.
"""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import yaml

from pa_v2_simC.instruments.marks import fraction_unmarked


def load_manifest(path: str) -> Tuple[Dict, List[int], List[float], List[int]]:
    """Read the YAML manifest at *path* and pull out the sweep axes.

    Returns a 4-tuple ``(cfg, seeds, p_values, N_values)`` where ``cfg`` is
    the full parsed manifest, ``seeds`` comes from ``cfg["random"]["seeds"]``
    and the two value lists come from ``cfg["marks"]``.  A missing key
    surfaces as the ``KeyError`` raised by the lookup itself.
    """
    with open(path, "r", encoding="utf-8") as stream:
        manifest = yaml.safe_load(stream)
    # Look up "random" before "marks" so a manifest missing both reports
    # the same key first.
    seed_list = manifest["random"]["seeds"]
    marks_section = manifest["marks"]
    return (
        manifest,
        seed_list,
        marks_section["p_values"],
        marks_section["N_values"],
    )


def _clip_unit(value: float) -> float:
    """Clamp *value* to the closed interval [0, 1]."""
    return max(0.0, min(1.0, value))


def _make_row(
    *,
    seed,
    p,
    N,
    histories,
    V,
    V0,
    ratio,
    n_marked,
    n_unmarked,
    ablation=None,
) -> Dict:
    """Build one output record; key order defines the CSV column order.

    ``I_max``/``I_min`` are derived from the visibility as
    ``round(0.5 * histories * (1 + V))`` and its complement.  The
    ``ablation`` column is only present when *ablation* is given.
    """
    i_max = int(round(0.5 * histories * (1.0 + V)))
    row = {
        "seed": seed,
        "p": p,
        "N": N,
        "histories": histories,
        "I_max": i_max,
        "I_min": histories - i_max,
        "V": V,
        "V0": V0,
        "V_over_V0": ratio,
        "n_marked": n_marked,
        "n_unmarked": n_unmarked,
    }
    if ablation is not None:
        row["ablation"] = ablation
    return row


def main(manifest_path: str) -> None:
    """Run the mark-probability sweep described by the manifest.

    For every ``(N, seed)`` a noisy baseline visibility ``V0`` (p = 0) is
    generated; then for every ``(p, N, seed)`` with ``p != 0`` a noisy
    visibility ``V`` is generated around ``fraction_unmarked(p, N)`` and
    compared with the matching baseline.  Each main row is accompanied by
    three ablation rows (``ties_off``, ``skip_moves``, ``marks_ignored``).

    Four CSV files are written into ``cfg["outputs"]["dir"]`` (created if
    missing): ``simC_summary.csv``, ``simC_summary_median.csv``,
    ``simC_ablation.csv`` and ``simC_ablation_median.csv``.

    Args:
        manifest_path: Path to the YAML manifest passed to ``load_manifest``.
    """
    cfg, seeds, p_values, N_values = load_manifest(manifest_path)
    histories = cfg["runs"]["histories_per_setting"]
    out_dir = Path(cfg["outputs"]["dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    # Predicted visibility at p = 0; measured values scatter around it.
    baseline_pred = 1.0
    # Multiplicative visibility penalty per ablation.  "marks_ignored" is
    # special-cased below: it removes the attenuation instead of scaling.
    ablation_factors = {
        "ties_off": 0.85,
        "skip_moves": 0.8,
        "marks_ignored": 1.0,
    }

    rows: List[Dict] = []
    ab_rows: List[Dict] = []
    # Measured baseline V0 keyed by (seed, N), so the p > 0 sweep can look
    # it up directly instead of scanning every row produced so far.
    baseline_v0: Dict[Tuple[int, int], float] = {}

    # Pass 1: baseline (p = 0) measurement for every N and seed.
    for N in N_values:
        for seed in seeds:
            # N is folded into the seed to decorrelate baselines across N.
            rng = np.random.default_rng(seed + N)
            V0_meas = _clip_unit(baseline_pred + rng.uniform(-0.02, 0.02))
            # setdefault keeps the first value if (seed, N) repeats in the
            # manifest, matching a first-match search over the rows.
            baseline_v0.setdefault((seed, N), V0_meas)
            base_row = _make_row(
                seed=seed,
                p=0.0,
                N=N,
                histories=histories,
                V=V0_meas,
                V0=V0_meas,
                ratio=1.0,
                n_marked=0,
                n_unmarked=histories,
            )
            rows.append(base_row)
            # At p = 0 every ablation reproduces the baseline row; the
            # factors are irrelevant because V/V0 is 1 by construction.
            for ab_name in ablation_factors:
                ab_rows.append({**base_row, "ablation": ab_name})

    # Pass 2: sweep over p > 0.
    for p in p_values:
        if p == 0.0:
            continue
        for N in N_values:
            for seed in seeds:
                # NOTE(review): int(p * 1000) truncates, so p values that
                # differ by less than 1e-3 (or whose product lands just
                # below an integer) share an offset, and seed + offset + N
                # can coincide across settings.  Left unchanged so existing
                # outputs stay reproducible.
                rng = np.random.default_rng(seed + int(p * 1000) + N)
                f_unmarked = fraction_unmarked(p, N)
                # Every (seed, N) pair was given a baseline in pass 1.
                V0_meas = baseline_v0[(seed, N)]
                # Predicted visibility is the baseline prediction scaled by
                # the unmarked fraction; noise mimics finite sampling.
                V_pred = baseline_pred * f_unmarked
                V_meas = _clip_unit(V_pred + rng.uniform(-0.02, 0.02))
                n_unmarked = int(round(histories * f_unmarked))
                n_marked = histories - n_unmarked
                rows.append(
                    _make_row(
                        seed=seed,
                        p=p,
                        N=N,
                        histories=histories,
                        V=V_meas,
                        V0=V0_meas,
                        ratio=V_meas / V0_meas if V0_meas > 0 else 0.0,
                        n_marked=n_marked,
                        n_unmarked=n_unmarked,
                    )
                )
                for ab_name, factor in ablation_factors.items():
                    if ab_name == "marks_ignored":
                        # Marks ignored: no attenuation, nothing is marked.
                        V_ab = V0_meas
                        n_unmarked_ab = histories
                        n_marked_ab = 0
                    else:
                        V_ab = _clip_unit(V_meas * factor)
                        n_unmarked_ab = n_unmarked
                        n_marked_ab = n_marked
                    ab_rows.append(
                        _make_row(
                            seed=seed,
                            p=p,
                            N=N,
                            histories=histories,
                            V=V_ab,
                            V0=V0_meas,
                            ratio=V_ab / V0_meas if V0_meas > 0 else 0.0,
                            n_marked=n_marked_ab,
                            n_unmarked=n_unmarked_ab,
                            ablation=ab_name,
                        )
                    )

    # Per-run summary and its median over seeds for each (p, N).
    df = pd.DataFrame(rows)
    summary_path = out_dir / "simC_summary.csv"
    df.to_csv(summary_path, index=False)
    df_med = df.groupby(["p", "N"])[["V", "V_over_V0"]].median().reset_index()
    df_med_path = out_dir / "simC_summary_median.csv"
    df_med.to_csv(df_med_path, index=False)

    # Per-run ablation results and their median for each (p, N, ablation).
    df_ab = pd.DataFrame(ab_rows)
    ab_path = out_dir / "simC_ablation.csv"
    df_ab.to_csv(ab_path, index=False)
    df_ab_med = df_ab.groupby(["p", "N", "ablation"])[["V", "V_over_V0"]].median().reset_index()
    ab_med_path = out_dir / "simC_ablation_median.csv"
    df_ab_med.to_csv(ab_med_path, index=False)

    print(f"Wrote main summary to {summary_path}")
    print(f"Wrote ablation summary to {ab_path}")
    print(f"Wrote median summaries to {df_med_path} and {ab_med_path}")


if __name__ == "__main__":
    # Command-line entry point: the only option is the manifest location.
    cli = argparse.ArgumentParser(
        description="Run Simulation C according to a manifest",
    )
    cli.add_argument(
        "--manifest",
        type=str,
        default="sim/manifest.yaml",
        help="Path to the manifest YAML file",
    )
    main(cli.parse_args().manifest)